"""
This code is based on code from https://github.com/tristandeleu/ntm-one-shot
and https://github.com/kjunelee/MetaOptNet
"""
import numpy as np
import random
import pickle as pkl
from functions import *
import torch
import time
import torch.nn.functional as F

class miniImageNetGenerator(object):
    """Iterator over few-shot episodes drawn from a pickled miniImageNet split.

    Each step yields ``(iteration_index, x_spt, y_spt, x_qry, y_qry)``. The support
    set is split across ``num_user`` users; the query set is shared. Iteration stops
    after ``max_iter`` episodes, or never when ``max_iter`` is ``None``.
    """

    def __init__(self, data_file, nb_classes, num_user, n_spt,
                  n_qry, max_iter=None):
        """
        Args:
            data_file: path to a pickle holding a dict with ``'data'`` (an array
                indexable by a list of indices; presumably (N, H, W, C) given the
                channel permute in ``_load_data``) and ``'labels'`` (one label per row).
            nb_classes: number of classes sampled per episode.
            num_user: number of users, each receiving its own support images.
            n_spt: support images per class per user.
            n_qry: query images per class per user; stored as ``n_qry * num_user``
                because the query set is shared by all users.
            max_iter: number of episodes before ``StopIteration``; ``None`` = unbounded.
        """
        super(miniImageNetGenerator, self).__init__()
        self.data_file = data_file
        self.nb_classes = nb_classes
        self.max_iter = max_iter
        self.num_iter = 0
        self.data_dict = self._load_data(self.data_file)
        self.num_user = num_user
        self.n_spt = n_spt
        # Total query images per class in an episode (shared across users).
        self.n_qry = n_qry * num_user

    def _load_data(self, data_file):
        """Load the pickle and group its images by label.

        Returns:
            dict mapping each label to a tensor of that label's images, with the last
            axis of ``data`` moved to position 1 (channels-last -> channels-first).
        """
        dataset = self.load_data(data_file)
        data = dataset['data']
        labels = dataset['labels']
        label2ind = self.buildLabelIndex(labels)

        # ``data[idx]`` is already a fresh copy, so ``as_tensor`` avoids a second one.
        return {label: torch.as_tensor(data[idx]).permute(0, 3, 1, 2)
                for label, idx in label2ind.items()}

    def load_data(self, data_file):
        """Unpickle ``data_file``, falling back to latin-1 decoding.

        The fallback handles pickles written by Python 2, whose ``str`` payloads
        (e.g. raw numpy buffers) fail to decode as ASCII under Python 3.

        NOTE: pickle can execute arbitrary code; only load trusted dataset files.
        """
        try:
            with open(data_file, 'rb') as fo:
                return pkl.load(fo)
        except UnicodeError:
            with open(data_file, 'rb') as fo:
                return pkl.load(fo, encoding='latin1')

    def buildLabelIndex(self, labels):
        """Map each label to the list of row indices carrying it (first-seen order)."""
        label2inds = {}
        for idx, label in enumerate(labels):
            label2inds.setdefault(label, []).append(idx)
        return label2inds

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return ``(index, x_spt, y_spt, x_qry, y_qry)`` for the next episode.

        Raises:
            StopIteration: once ``max_iter`` episodes have been produced.
        """
        if (self.max_iter is None) or (self.num_iter < self.max_iter):
            self.num_iter += 1
            x_spt, y_spt, x_qry, y_qry = self.sample(self.nb_classes)

            return (self.num_iter - 1), x_spt, y_spt, x_qry, y_qry
        else:
            raise StopIteration()

    def augment(self, img, pad=8):
        """Zero-pad ``img`` by ``pad`` on every side, then randomly crop back to its size.

        Args:
            img: tensor of shape (C, H, W).
            pad: padding per side; crop offsets are drawn uniformly from [0, 2 * pad].

        Returns:
            Tensor of shape (C, H, W).
        """
        height, width = img.shape[-2:]
        img = F.pad(img, (pad, pad, pad, pad))  # (C, H + 2*pad, W + 2*pad)
        x = random.randint(0, 2 * pad)
        y = random.randint(0, 2 * pad)
        # Random flipping is intentionally disabled.
        return img[:, y:y + height, x:x + width]

    def sample(self, nb_classes):
        """Draw one episode of ``nb_classes`` randomly chosen classes.

        For each class, ``num_user * n_spt + n_qry`` distinct images are drawn without
        replacement. They are split into one support chunk per user followed by the
        query chunk, so no image appears twice in an episode. Only support images are
        augmented. Labels are episode-local (0 .. nb_classes - 1).

        Returns:
            x_spt: float tensor (num_user, nb_classes * n_spt, C, H, W).
            y_spt: int64 tensor (num_user, nb_classes * n_spt).
            x_qry: float tensor (nb_classes * n_qry, C, H, W).
            y_qry: int64 tensor (nb_classes * n_qry,).

        Raises:
            ValueError: if there are fewer classes than ``nb_classes`` or a class has
                too few images for the requested split.
        """
        # random.sample requires a sequence: sets / dict views raise on Python 3.11+.
        sampled_class = random.sample(list(self.data_dict), nb_classes)
        if sampled_class:
            img_shape = tuple(self.data_dict[sampled_class[0]].shape[1:])
        else:
            img_shape = (3, 84, 84)

        n_spt_total = nb_classes * self.n_spt
        x_spt = torch.zeros(self.num_user, n_spt_total, *img_shape)
        y_spt = torch.zeros(self.num_user, n_spt_total, dtype=torch.long)
        x_qry = torch.zeros(self.n_qry * nb_classes, *img_shape)
        y_qry = torch.zeros(self.n_qry * nb_classes, dtype=torch.long)

        n_support_imgs = self.num_user * self.n_spt
        for clsidx, _class in enumerate(sampled_class):
            imgs = self.data_dict[_class]  # (num_images, C, H, W)
            # One draw without replacement keeps all support/query images distinct.
            chosen = random.sample(range(len(imgs)), n_support_imgs + self.n_qry)

            spt_start = self.n_spt * clsidx
            for user in range(self.num_user):
                user_idx = chosen[user * self.n_spt:(user + 1) * self.n_spt]
                for offset, img_idx in enumerate(user_idx):
                    x_spt[user, spt_start + offset] = self.augment(imgs[img_idx])
                y_spt[user, spt_start:spt_start + self.n_spt] = clsidx

            qry = slice(self.n_qry * clsidx, self.n_qry * (clsidx + 1))
            x_qry[qry] = imgs[chosen[n_support_imgs:]]
            y_qry[qry] = clsidx

        return x_spt, y_spt, x_qry, y_qry


